import librosa
import math
import os
import re
import random
import numpy as np


class GenreFeatureData:
    """Audio feature pipeline for GTZAN music-genre classification.

    Extracts per-frame MFCC, spectral-centroid, chroma, and spectral-contrast
    features (33 dimensions per frame) over a fixed window of 128 frames per
    track, and caches the resulting feature tensors and one-hot genre targets
    as ``.npy`` files under ``./gtzan``.
    """

    hop_length = None  # STFT hop in samples; set to 512 per instance in __init__
    genre_list = ["classical", "hiphop", "jazz", "metal", "pop", "reggae"]
    # Full 10-genre GTZAN alternative:
    # genre_list = ['blues','classical','country','disco','hiphop','jazz','metal','pop','reggae','rock']

    dir_trainfolder = "./gtzan/_train"
    dir_devfolder = "./gtzan/_validation"
    dir_testfolder = "./gtzan/_test"
    dir_all_files = "./gtzan/"

    train_X_preprocessed_data = "./gtzan/data_train_input.npy"
    train_Y_preprocessed_data = "./gtzan/data_train_target.npy"
    dev_X_preprocessed_data = "./gtzan/data_validation_input.npy"
    dev_Y_preprocessed_data = "./gtzan/data_validation_target.npy"
    test_X_preprocessed_data = "./gtzan/data_test_input.npy"
    test_Y_preprocessed_data = "./gtzan/data_test_target.npy"

    train_X = train_Y = None
    dev_X = dev_Y = None
    test_X = test_Y = None

    def __init__(self):
        self.hop_length = 512

        self.timeseries_length_list = []
        self.trainfiles_list = self.path_to_audiofiles(self.dir_trainfolder)
        self.devfiles_list = self.path_to_audiofiles(self.dir_devfolder)
        self.testfiles_list = self.path_to_audiofiles(self.dir_testfolder)

        self.all_files_list = []
        self.all_files_list.extend(self.trainfiles_list)
        self.all_files_list.extend(self.devfiles_list)
        self.all_files_list.extend(self.testfiles_list)

        # Computing the true minimum timeseries length over the corpus is slow
        # (see precompute_min_timeseries_len), so a fixed value is used instead.
        # 128 frames with default fft size 2048, hop 512 at SR 22050 covers
        # roughly 3 seconds of audio per track, which is a bit small!
        self.timeseries_length = 128

    def load_preprocess_data(self):
        """Extract features for train/dev/test splits and cache them as .npy files."""
        print("[DEBUG] total number of files: " + str(len(self.all_files_list)))

        self.train_X, self.train_Y = self._extract_and_save(
            self.trainfiles_list,
            self.train_X_preprocessed_data,
            self.train_Y_preprocessed_data,
        )
        self.dev_X, self.dev_Y = self._extract_and_save(
            self.devfiles_list,
            self.dev_X_preprocessed_data,
            self.dev_Y_preprocessed_data,
        )
        self.test_X, self.test_Y = self._extract_and_save(
            self.testfiles_list,
            self.test_X_preprocessed_data,
            self.test_Y_preprocessed_data,
        )

    def _extract_and_save(self, file_list, x_path, y_path):
        # Extract features for one split, one-hot its labels, persist both arrays.
        split_x, split_y = self.extract_audio_features(file_list)
        split_y = self.one_hot(split_y)
        with open(x_path, "wb") as f:
            np.save(f, split_x)
        with open(y_path, "wb") as f:
            np.save(f, split_y)
        return split_x, split_y

    def load_deserialize_data(self):
        """Load previously cached feature/target arrays from disk."""
        self.train_X = np.load(self.train_X_preprocessed_data)
        self.train_Y = np.load(self.train_Y_preprocessed_data)

        self.dev_X = np.load(self.dev_X_preprocessed_data)
        self.dev_Y = np.load(self.dev_Y_preprocessed_data)

        self.test_X = np.load(self.test_X_preprocessed_data)
        self.test_Y = np.load(self.test_Y_preprocessed_data)

    def precompute_min_timeseries_len(self):
        """Record each track's frame count; slow — decodes every audio file."""
        for file in self.all_files_list:
            print("Loading " + str(file))
            y, sr = librosa.load(file)
            self.timeseries_length_list.append(math.ceil(len(y) / self.hop_length))

    def extract_audio_features(self, list_of_audiofiles):
        """Build an aligned (n_ok, timeseries_length, 33) feature tensor + labels.

        Per-frame feature layout: [0:13] MFCC, [13:14] spectral centroid,
        [14:26] chroma, [26:33] spectral contrast.  Files that fail to decode
        are skipped entirely, so feature rows and labels stay aligned (the
        previous implementation left a zero row behind and then deleted the
        wrong row with ``np.delete(data, -(error_num-1))``).

        Returns:
            (data, target): standardized float64 features of shape
            (n_ok, timeseries_length, 33) and an (n_ok, 1) array of
            genre-name strings.
        """
        data = np.zeros(
            (len(list_of_audiofiles), self.timeseries_length, 33), dtype=np.float64
        )
        target = []
        failed_files = []
        row = 0  # next row of `data` to fill; advances only on success

        for i, file in enumerate(list_of_audiofiles):
            try:
                y, sr = librosa.load(file)
            except Exception:  # decode/I-O failure; never masks KeyboardInterrupt
                failed_files.append(file)
                continue

            mfcc = librosa.feature.mfcc(
                y=y, sr=sr, hop_length=self.hop_length, n_mfcc=13
            )
            spectral_center = librosa.feature.spectral_centroid(
                y=y, sr=sr, hop_length=self.hop_length
            )
            chroma = librosa.feature.chroma_stft(y=y, sr=sr, hop_length=self.hop_length)
            spectral_contrast = librosa.feature.spectral_contrast(
                y=y, sr=sr, hop_length=self.hop_length
            )

            # GTZAN files are named "<genre>.<track>.au", so the genre is the
            # basename segment before the first dot — this works for both flat
            # and per-genre-subfolder layouts, unlike the old hard-coded
            # re.split(...)[3] indexing.
            genre = os.path.basename(file).split(".")[0]

            data[row, :, 0:13] = mfcc.T[0:self.timeseries_length, :]
            data[row, :, 13:14] = spectral_center.T[0:self.timeseries_length, :]
            data[row, :, 14:26] = chroma.T[0:self.timeseries_length, :]
            data[row, :, 26:33] = spectral_contrast.T[0:self.timeseries_length, :]
            target.append(genre)
            row += 1

            print(
                "Extracted features audio track %i of %i."
                % (i + 1, len(list_of_audiofiles))
            )

        if failed_files:
            print(f"There are errors in processing data {failed_files}")
        # Trim the rows reserved for files that failed to load.
        data = data[:row]

        target_p = np.expand_dims(np.asarray(target), axis=1)
        print(data.shape)
        print(target_p.shape)
        data = self.normalize_features(data, "standard")
        return data, target_p

    def one_hot(self, Y_genre_strings):
        """Convert an (n, 1) array of genre-name strings to (n, n_genres) one-hot rows."""
        y_one_hot = np.zeros((Y_genre_strings.shape[0], len(self.genre_list)))
        for i, genre_string in enumerate(Y_genre_strings):
            # Rows arrive as length-1 string arrays; extract the scalar label
            # so list.index compares string-to-string.
            label = (
                genre_string[0]
                if isinstance(genre_string, np.ndarray)
                else genre_string
            )
            y_one_hot[i, self.genre_list.index(label)] = 1
        return y_one_hot

    def normalize_features(self, data, type='min-max'):
        """
        Normalize or standardize the feature data.

        Parameters:
        - data: numpy array of data to normalize or standardize
        - type: 'min-max' for Min-Max normalization, 'standard' for standardization

        Returns:
        - normalized or standardized data (unchanged if `type` is unrecognized)
        """
        # Min-Max normalization
        if type == 'min-max':
            min_vals = data.min(axis=0)  # Minimum value per feature
            max_vals = data.max(axis=0)  # Maximum value per feature
            span = max_vals - min_vals
            span[span == 0] = 1.0  # constant features map to 0 instead of NaN
            data = (data - min_vals) / span

        # Standardization (Z-score normalization)
        elif type == 'standard':
            mean_vals = data.mean(axis=0)  # Mean value per feature
            std_dev = data.std(axis=0)     # Standard deviation per feature
            std_dev[std_dev == 0] = 1.0    # guard against division by zero
            data = (data - mean_vals) / std_dev

        return data

    @staticmethod
    def path_to_audiofiles(dir_folder, extension=".au"):
        """Recursively collect paths of audio files ending in `extension`."""
        return [
            os.path.join(root, name)
            for root, _dirs, names in os.walk(dir_folder)
            for name in names
            if name.endswith(extension)
        ]

    @staticmethod
    def train_test_dev_split(data, test_ratio, dev_ratio, random_seed=42):
        """Randomly partition `data` into (train, test, dev) lists.

        Shuffles a *copy* with a local seeded RNG, so the caller's list is
        left untouched and the global `random` state is not perturbed; the
        split is reproducible for a given `random_seed`.
        """
        items = list(data)
        random.Random(random_seed).shuffle(items)

        # Size the held-out sets by ratio (floored to whole items).
        test_size = int(len(items) * test_ratio)
        dev_size = int(len(items) * dev_ratio)

        test_set = items[:test_size]
        dev_set = items[test_size:test_size + dev_size]
        train_set = items[test_size + dev_size:]

        return train_set, test_set, dev_set